{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Adaboost.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CGJ1QiK3cnzN",
        "colab_type": "text"
      },
      "source": [
        "# Chapter 8 Boosting Methods"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": true,
        "id": "v_MmENfgcnzN",
        "colab_type": "text"
      },
      "source": [
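        "# Boosting\n",
        "\n",
        "Bagging and boosting are the two main methods for building ensemble models. An ensemble model is a model made up of several base models, and its predictions are often better than those of any single base model.\n",
        "\n",
        "- Bagging: each base model is trained on a different dataset obtained by randomly resampling the full training sample with replacement; this process of producing different training sets by resampling is called bagging.\n",
        "\n",
        "- Boosting: each base model is trained with different sample weights; the weights of the samples misclassified by the previous base model are increased, so that the new model concentrates on those samples.\n",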
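        "\n",
        "A minimal NumPy sketch of this difference on a hypothetical 10-sample dataset: bagging hands each base model a bootstrap sample drawn with replacement, while boosting keeps all samples and changes only their weights. The misclassified indices and the doubling factor are illustrative assumptions, not AdaBoost's actual rule, which is given in step 5) below.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "N = 10\n",
        "X = np.arange(N)  # hypothetical 1-D samples\n",
        "\n",
        "# Bagging: each base model gets a bootstrap sample of size N drawn with replacement\n",
        "bag_idx = rng.integers(0, N, size=N)\n",
        "X_bag = X[bag_idx]  # some samples repeat, others are left out\n",
        "\n",
        "# Boosting: each base model sees all N samples, but with different weights\n",
        "w = np.full(N, 1.0 / N)\n",
        "missed = np.zeros(N, dtype=bool)\n",
        "missed[[6, 7, 8]] = True  # hypothetical samples the previous model got wrong\n",
        "w[missed] *= 2.0  # illustrative up-weighting only, not the AdaBoost formula\n",
        "w /= w.sum()  # renormalize so the weights sum to 1\n",
        "```\n",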
        "\n",
        "### AdaBoost\n",
        "\n",
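        "AdaBoost is short for Adaptive Boosting, i.e. a boosting algorithm that adapts to the mistakes of the models built so far.\n",
        "\n",
        "The algorithm proceeds as follows:\n",
        "\n",
        "1) Assign a weight to each training sample $x_{1},x_{2},\\ldots,x_{N}$; the initial weights $w_{1,i}$ are all $1/N$.\n",
        "\n",
        "2) Train a model $G_m$ on the weighted samples (the first model is $G_1$).\n",
        "\n",
        "3) Compute the weighted error rate of $G_m$: $e_m=\\sum_{i=1}^N w_{m,i}I(y_i\\not= G_m(x_i))$.\n",
        "\n",
        "4) Compute the coefficient of $G_m$: $\\alpha_m=\\frac{1}{2}\\log\\frac{1-e_m}{e_m}$.\n",
        "\n",
        "5) Update the weight vector $w_m$ to $w_{m+1}$ using $\\alpha_m$ (and hence $e_m$): $w_{m+1,i}=\\frac{w_{m,i}}{Z_m}\\exp(-\\alpha_m y_i G_m(x_i))$, where $Z_m$ normalizes the new weights to sum to 1.\n",
        "\n",
        "6) Compute the error rate of the combined model $f(x)=\\sum_{m=1}^M\\alpha_mG_m(x)$, whose prediction is $\\mathrm{sign}(f(x))$.\n",
        "\n",
        "7) Stop when the error rate of the combined model falls below a threshold or the number of iterations reaches a preset limit; otherwise return to step 2)."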
      ]
    },
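    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Steps 3) to 5) can be checked by hand for one round. This sketch assumes the data of Example 8.1 (later in this notebook) and a hypothetical first model $G_1$ that predicts $+1$ for $x < 2.5$ and $-1$ otherwise, so it misclassifies $x = 6, 7, 8$.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "x = np.arange(10)\n",
        "y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])\n",
        "w = np.full(10, 0.1)            # step 1): uniform weights 1/N\n",
        "G1 = np.where(x < 2.5, 1, -1)   # hypothetical first base model\n",
        "\n",
        "e1 = w[G1 != y].sum()                  # step 3): weighted error rate\n",
        "alpha1 = 0.5 * np.log((1 - e1) / e1)   # step 4): model coefficient\n",
        "w2 = w * np.exp(-alpha1 * y * G1)      # step 5): shrink correct, grow misclassified\n",
        "w2 = w2 / w2.sum()                     # divide by Z_1 so the weights sum to 1\n",
        "```\n",
        "\n",
        "This gives $e_1 = 0.3$ and $\\alpha_1 \\approx 0.4236$. After normalization the three misclassified samples carry half of the total weight ($1/6$ each) and the seven correct ones the other half ($1/14$ each). Note that $\\alpha_m > 0$ exactly when $e_m < 0.5$, so only better-than-chance models receive a positive vote."
      ]
    },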
    {
      "cell_type": "code",
      "metadata": {
        "id": "WkmYWWexcnzO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "from sklearn.datasets import load_iris\n",
        "from sklearn.tree import DecisionTreeClassifier\n",
        "from sklearn.model_selection  import train_test_split\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kWFOcuTKcnzR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# data: setosa vs. versicolor (first 100 iris samples), sepal length and width only\n",
        "def create_data():\n",
        "    iris = load_iris()\n",
        "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
        "    df['label'] = iris.target\n",
        "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
        "    # columns 0, 1 and -1: sepal length, sepal width, label\n",
        "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
        "    # the AdaBoost code below expects labels in {-1, +1}: map class 0 (setosa) to -1\n",
        "    data[data[:, -1] == 0, -1] = -1\n",
        "    return data[:, :2], data[:, -1]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uk2Mg38UcnzT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X, y = create_data()\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FNCiiDMycnzW",
        "colab_type": "code",
        "outputId": "abb8a27d-9db0-449e-e78e-b019f70c2586",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 287
        }
      },
      "source": [
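        "# first 50 rows: setosa (label -1); last 50 rows: versicolor (label +1)\n",
        "plt.scatter(X[:50,0],X[:50,1], label='-1')\n",
        "plt.scatter(X[50:,0],X[50:,1], label='1')\n",
        "plt.xlabel('sepal length')\n",
        "plt.ylabel('sepal width')\n",
        "plt.legend()"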
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.legend.Legend at 0x7ff301ceac50>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGZhJREFUeJzt3X+MXWWdx/H3d4dZOqvQCWVUmCk7\naE2jQNfCCJJuiAtxq7WWBtlS4q8qa3cNLhhcjBiC2piAS4LKkmgqZAFhi92K5cdCWQISf0RqpoDt\n2kpEQTsDuwyDLbIWaMfv/nHvtDO3M3Pvc+89c5/nuZ9X0sycc0/PfJ9z4Ns753zOc83dERGRvPxZ\nqwsQEZHmU3MXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGTqi1g3NrAMY\nBIbdfXnFa2uAa4Hh8qob3P3GmfZ37LHHen9/f1CxIiLtbtu2bS+4e0+17Wpu7sClwC7g6Gle/667\nf7rWnfX39zM4OBjw40VExMx+W8t2NV2WMbM+4P3AjO/GRUQkDrVec/868DngTzNs80Ez225mm8xs\n/lQbmNlaMxs0s8GRkZHQWkVEpEZVm7uZLQeed/dtM2x2D9Dv7ouAB4FbptrI3de7+4C7D/T0VL1k\nJCIidarlmvsSYIWZLQPmAEeb2W3u/uHxDdx9dML2NwL/0twyRUSaZ//+/QwNDfHKK6+0upRpzZkz\nh76+Pjo7O+v6+1Wbu7tfAVwBYGbvBv55YmMvrz/O3Z8rL66gdONVRCRKQ0NDHHXUUfT392NmrS7n\nMO7O6OgoQ0NDnHjiiXXto+6cu5mtM7MV5cVLzOwXZvZz4BJgTb37FREp2iuvvMK8efOibOwAZsa8\nefMa+s0iJAqJuz8CPFL+/qoJ6w++uxfJzebHh7n2gSd5ds8+ju/u4vKlC1m5uLfVZUmDYm3s4xqt\nL6i5i7SbzY8Pc8WdO9i3fwyA4T37uOLOHQBq8BI1TT8gMoNrH3jyYGMft2//GNc+8GSLKpJcbNmy\nhYULF7JgwQKuueaapu9fzV1kBs/u2Re0XqQWY2NjXHzxxdx///3s3LmTDRs2sHPnzqb+DF2WEZnB\n8d1dDE/RyI/v7mpBNdIqzb7v8rOf/YwFCxbw5je/GYDVq1dz11138fa3v71ZJeudu8hMLl+6kK7O\njknrujo7uHzpwhZVJLNt/L7L8J59OIfuu2x+fLjq353O8PAw8+cfepC/r6+P4eH69zcVNXeRGaxc\n3MvV551Cb3cXBvR2d3H1eafoZmobSfW+iy7LiFSxcnGvmnkbK+K+S29vL7t37z64PDQ0RG9vc/8b\n0zt3EZEZTHd/pZH7Lu985zv51a9+xdNPP81rr73GHXfcwYoVK6r/xQBq7iIiMyjivssRRxzBDTfc\nwNKlS3nb297GqlWrOOmkkxotdfLPaOreREQyM35JrtlPKS9btoxly5Y1o8QpqbmLiFSR4n0XXZYR\nEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqblLNjY/PsySax7mxM//J0uuebihuT9EivaJT3yCN7zh\nDZx88smF7F/NXbJQxOROIkVas2YNW7ZsKWz/au6ShVQnd5JEbN8IXzsZvtRd+rp9Y8O7POusszjm\nmGOaUNzU9BCTZEEfqiGF2b4R7rkE9pf/W9q7u7QMsGhV6+qqQu/cJQtFTO4kAsBD6w419nH795XW\nR0zNXbKgD9WQwuwdClsfCV2WkSwUNbmTCHP7SpdiplofMTV3yUaKkztJAs65avI1d4DOrtL6Blx4\n4YU88sgjvPDCC/T19fHlL3+Ziy66qMFiD1Fzl4Y1+8ODRaIyftP0oXWlSzFz+0qNvcGbqRs2bGhC\ncdNTc5eGjOfLx2OI4/lyQA1e
8rFoVdTJmKnohqo0RPlykTipuUtDlC+XVLl7q0uYUaP1qblLQ5Qv\nlxTNmTOH0dHRaBu8uzM6OsqcOXPq3oeuuUtDLl+6cNI1d1C+XOLX19fH0NAQIyMjrS5lWnPmzKGv\nr/64pZq7NET5cklRZ2cnJ554YqvLKFTNzd3MOoBBYNjdl1e8diRwK3AaMApc4O7PNLFOiZjy5SLx\nCXnnfimwCzh6itcuAn7v7gvMbDXwVeCCJtQnkhRl/iUWNd1QNbM+4P3AjdNsci5wS/n7TcA5ZmaN\nlyeSDs0pLzGpNS3zdeBzwJ+meb0X2A3g7geAvcC8hqsTSYgy/xKTqs3dzJYDz7v7tkZ/mJmtNbNB\nMxuM+S61SD2U+ZeY1PLOfQmwwsyeAe4Azjaz2yq2GQbmA5jZEcBcSjdWJ3H39e4+4O4DPT09DRUu\nEhtl/iUmVZu7u1/h7n3u3g+sBh529w9XbHY38LHy9+eXt4nz6QCRgmhOeYlJ3Tl3M1sHDLr73cBN\nwHfM7CngRUr/CIi0FWX+JSbWqjfYAwMDPjg42JKfLSKSKjPb5u4D1bbTE6oSrSs372DD1t2MudNh\nxoVnzOcrK09pdVkiSVBzlyhduXkHtz36u4PLY+4Hl9XgRarTrJASpQ1bp/jMyhnWi8hkau4SpbFp\n7gVNt15EJlNzlyh1TDN7xXTrRWQyNXeJ0oVnzA9aLyKT6YaqRGn8pqnSMiL1Uc5dRCQhyrlLQz70\n7Z/yk1+/eHB5yVuO4fZPntnCilpHc7RLinTNXQ5T2dgBfvLrF/nQt3/aoopaR3O0S6rU3OUwlY29\n2vqcaY52SZWau8gMNEe7pErNXWQGmqNdUqXmLodZ8pZjgtbnTHO0S6rU3OUwt3/yzMMaebumZVYu\n7uXq806ht7sLA3q7u7j6vFOUlpHoKecuIpIQ5dylIUVlu0P2q3y5SP3U3OUw49nu8QjgeLYbaKi5\nhuy3qBpE2oWuucthisp2h+xX+XKRxqi5y2GKynaH7Ff5cpHGqLnLYYrKdofsV/lykcaoucthisp2\nh+xX+XKRxuiGqhxm/IZls5MqIfstqgaRdqGcu4hIQpRzL1gMGezQGmKoWURmh5p7HWLIYIfWEEPN\nIjJ7dEO1DjFksENriKFmEZk9au51iCGDHVpDDDWLyOxRc69DDBns0BpiqFlEZo+aex1iyGCH1hBD\nzSIye3RDtQ4xZLBDa4ihZhGZPVVz7mY2B/ghcCSlfww2ufsXK7ZZA1wLjH8k/A3ufuNM+1XOXUQk\nXDNz7q8CZ7v7y2bWCfzYzO5390crtvuuu3+6nmJldly5eQcbtu5mzJ0OMy48Yz5fWXlKw9vGkp+P\npQ6RGFRt7l56a/9yebGz/Kc1j7VK3a7cvIPbHv3dweUx94PLlU07ZNtY8vOx1CESi5puqJpZh5k9\nATwPPOjuW6fY7INmtt3MNpnZ/KZWKQ3bsHV3zetDto0lPx9LHSKxqKm5u/uYu78D6ANON7OTKza5\nB+h390XAg8AtU+3HzNaa2aCZDY6MjDRStwQam+beylTrQ7aNJT8fSx0isQiKQrr7HuAHwHsr1o+6\n+6vlxRuB06b5++vdfcDdB3p6euqpV+rUYVbz+pBtY8nPx1KHSCyqNncz6zGz7vL3XcB7gF9WbHPc\nhMUVwK5mFimNu/CMqa+UTbU+ZNtY8vOx1CESi1rSMscBt5hZB6V/DDa6+71mtg4YdPe7gUvMbAVw\nAHgRWFNUwVKf8RuhtSRgQraNJT8fSx0isdB87iIiCdF87gUrKlMdki8vct8h40vxWCRn+0Z4aB3s\nHYK5fXDOVbBoVaurkoipudehqEx1SL68yH2HjC/FY5Gc7Rvhnktgfzn5s3d3aRnU4GVamjisDkVl\nqkPy5UXuO2R8KR6L5Dy07lBjH7d/X2m9yDTU3OtQVKY6JF9e5L5DxpfisUjO3qGw9SKoudelqE
x1\nSL68yH2HjC/FY5GcuX1h60VQc69LUZnqkHx5kfsOGV+KxyI551wFnRX/WHZ2ldaLTEM3VOtQVKY6\nJF9e5L5DxpfisUjO+E1TpWUkgHLuIiIJUc5dDhNDdl0Sp7x9MtTc20QM2XVJnPL2SdEN1TYRQ3Zd\nEqe8fVLU3NtEDNl1SZzy9klRc28TMWTXJXHK2ydFzb1NxJBdl8Qpb58U3VBtEzFk1yVxytsnRTl3\nEZGEKOdeVlReO2S/scxLrux6ZHLPjOc+vhAtOBZZN/ei8toh+41lXnJl1yOTe2Y89/GFaNGxyPqG\nalF57ZD9xjIvubLrkck9M577+EK06Fhk3dyLymuH7DeWecmVXY9M7pnx3McXokXHIuvmXlReO2S/\nscxLrux6ZHLPjOc+vhAtOhZZN/ei8toh+41lXnJl1yOTe2Y89/GFaNGxyPqGalF57ZD9xjIvubLr\nkck9M577+EK06Fgo5y4ikhDl3Aum/LxIIu69DLbdDD4G1gGnrYHl1zW+38hz/GrudVB+XiQR914G\ngzcdWvaxQ8uNNPgEcvxZ31AtivLzIonYdnPY+lolkONXc6+D8vMiifCxsPW1SiDHr+ZeB+XnRRJh\nHWHra5VAjl/NvQ7Kz4sk4rQ1YetrlUCOXzdU66D8vEgixm+aNjstk0COXzl3EZGENC3nbmZzgB8C\nR5a33+TuX6zY5kjgVuA0YBS4wN2fqaPuqkLz5anNYR6SXc/9WBSaIw7JPhdVR5HjizyD3ZDQseV8\nLGZQy2WZV4Gz3f1lM+sEfmxm97v7oxO2uQj4vbsvMLPVwFeBC5pdbGi+PLU5zEOy67kfi0JzxCHZ\n56LqKHJ8CWSw6xY6tpyPRRVVb6h6ycvlxc7yn8prOecCt5S/3wScY9b82EZovjy1OcxDsuu5H4tC\nc8Qh2eei6ihyfAlksOsWOracj0UVNaVlzKzDzJ4AngcedPetFZv0ArsB3P0AsBeYN8V+1prZoJkN\njoyMBBcbmi9PbQ7zkOx67sei0BxxSPa5qDqKHF8CGey6hY4t52NRRU3N3d3H3P0dQB9wupmdXM8P\nc/f17j7g7gM9PT3Bfz80X57aHOYh2fXcj0WhOeKQ7HNRdRQ5vgQy2HULHVvOx6KKoJy7u+8BfgC8\nt+KlYWA+gJkdAcyldGO1qULz5anNYR6SXc/9WBSaIw7JPhdVR5HjSyCDXbfQseV8LKqoJS3TA+x3\n9z1m1gW8h9IN04nuBj4G/BQ4H3jYC8hYhubLU5vDPCS7nvuxKDRHHJJ9LqqOIseXQAa7bqFjy/lY\nVFE1525miyjdLO2g9E5/o7uvM7N1wKC7312OS34HWAy8CKx299/MtF/l3EVEwjUt5+7u2yk17cr1\nV034/hXg70KLFBGRYmQ//UByD+7I7Ah5sCWGh2CKfHAntYe0YjgfCci6uSf34I7MjpAHW2J4CKbI\nB3dSe0grhvORiKxnhUzuwR2ZHSEPtsTwEEyRD+6k9pBWDOcjEVk39+Qe3JHZEfJgSwwPwRT54E5q\nD2nFcD4SkXVzT+7BHZkdIQ+2xPAQTJEP7qT2kFYM5yMRWTf35B7ckdkR8mBLDA/BFPngTmoPacVw\nPhKRdXNfubiXq887hd7uLgzo7e7i6vNO0c3UdrdoFXzgepg7H7DS1w9cP/UNuZBtY6g3dPuixpfa\nfjOkD+sQEUlI0x5iEml7IR/sEYvUao4lux5LHU2g5i4yk5AP9ohFajXHkl2PpY4myfqau0jDQj7Y\nIxap1RxLdj2WOppEzV1kJiEf7BGL1GqOJbseSx1NouYuMpOQD/aIRWo1x5Jdj6WOJlFzF5lJyAd7\nxCK1mmPJrsdSR5OouYvMZPl1MHDRoXe91lFajvHG5LjUao4lux5LHU2inLuISEKUc5fZk2I2uKia\ni8qXp3iMpaXU3KUxKWaDi6q5qHx5isdYWk7X3KUxKWaDi6
q5qHx5isdYWk7NXRqTYja4qJqLypen\neIyl5dTcpTEpZoOLqrmofHmKx1haTs1dGpNiNriomovKl6d4jKXl1NylMSlmg4uquah8eYrHWFpO\nOXcRkYTUmnPXO3fJx/aN8LWT4Uvdpa/bN87+fouqQSSQcu6Sh6Ky4CH7VR5dIqJ37pKHorLgIftV\nHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0i\nUjXnbmbzgVuBNwIOrHf3b1Rs827gLuDp8qo73X3Gu0jKuYuIhGvmfO4HgM+6+2NmdhSwzcwedPed\nFdv9yN2X11OsRCjF+cNDak5xfDHQcUtG1ebu7s8Bz5W//4OZ7QJ6gcrmLrlIMa+tPHrxdNySEnTN\n3cz6gcXA1ilePtPMfm5m95vZSU2oTVolxby28ujF03FLSs1PqJrZ64HvAZ9x95cqXn4M+Et3f9nM\nlgGbgbdOsY+1wFqAE044oe6ipWAp5rWVRy+ejltSanrnbmadlBr77e5+Z+Xr7v6Su79c/v4+oNPM\njp1iu/XuPuDuAz09PQ2WLoVJMa+tPHrxdNySUrW5m5kBNwG73H3KuUvN7E3l7TCz08v7HW1moTKL\nUsxrK49ePB23pNRyWWYJ8BFgh5k9UV73BeAEAHf/FnA+8CkzOwDsA1Z7q+YSlsaN3xxLKRURUnOK\n44uBjltSNJ+7iEhCmplzl1gpczzZvZfBtptLH0htHaWPt2v0U5BEEqXmnipljie79zIYvOnQso8d\nWlaDlzakuWVSpczxZNtuDlsvkjk191QpczyZj4WtF8mcmnuqlDmezDrC1otkTs09VcocT3bamrD1\nIplTc0+V5g6fbPl1MHDRoXfq1lFa1s1UaVPKuYuIJEQ59zpsfnyYax94kmf37OP47i4uX7qQlYt7\nW11W8+Sei899fDHQMU6GmnvZ5seHueLOHezbX0pXDO/ZxxV37gDIo8HnnovPfXwx0DFOiq65l137\nwJMHG/u4ffvHuPaBJ1tUUZPlnovPfXwx0DFOipp72bN79gWtT07uufjcxxcDHeOkqLmXHd/dFbQ+\nObnn4nMfXwx0jJOi5l52+dKFdHVOfuClq7ODy5cubFFFTZZ7Lj738cVAxzgpuqFaNn7TNNu0TO5z\ncec+vhjoGCdFOXcRkYTUmnPXZRmRFGzfCF87Gb7UXfq6fWMa+5aW0WUZkdgVmS9Xdj1beucuErsi\n8+XKrmdLzV0kdkXmy5Vdz5aau0jsisyXK7ueLTV3kdgVmS9Xdj1bau4isSty7n59LkC2lHMXEUmI\ncu4iIm1MzV1EJENq7iIiGVJzFxHJkJq7iEiG1NxFRDKk5i4ikiE1dxGRDFVt7mY238x+YGY7zewX\nZnbpFNuYmV1vZk+Z2XYzO7WYcqUhmrdbpG3UMp/7AeCz7v6YmR0FbDOzB91954Rt3ge8tfznDOCb\n5a8SC83bLdJWqr5zd/fn3P2x8vd/AHYBlR8sei5wq5c8CnSb2XFNr1bqp3m7RdpK0DV3M+sHFgNb\nK17qBXZPWB7i8H8AMLO1ZjZoZoMjIyNhlUpjNG+3SFupubmb2euB7wGfcfeX6vlh7r7e3QfcfaCn\np6eeXUi9NG+3SFupqbmbWSelxn67u985xSbDwPwJy33ldRILzdst0lZqScsYcBOwy92vm2azu4GP\nllMz7wL2uvtzTaxTGqV5u0XaSi1pmSXAR4AdZvZEed0XgBMA3P1bwH3AMuAp4I/Ax5tfqjRs0So1\nc5E2UbW5u/uPAauyjQMXN6soERFpjJ5QFRHJkJq7iEiG1NxFRDKk5i4ikiE1dxGRDKm5i4hkSM1d\nRCRDVoqot+AHm40Av23JD6/uWOCFVhdRII0vXTmPDTS+Wvylu1ednKtlzT1mZjbo7gOtrqMoGl+6\nch4baHzNpMsyIiIZUn
MXEcmQmvvU1re6gIJpfOnKeWyg8TWNrrmLiGRI79xFRDLU1s3dzDrM7HEz\nu3eK19aY2YiZPVH+8/etqLERZvaMme0o1z84xetmZteb2VNmtt3MTm1FnfWoYWzvNrO9E85fUh85\nZWbdZrbJzH5pZrvM7MyK15M9d1DT+JI9f2a2cELdT5jZS2b2mYptCj9/tXxYR84uBXYBR0/z+nfd\n/dOzWE8R/sbdp8vVvg94a/nPGcA3y19TMdPYAH7k7stnrZrm+gawxd3PN7M/B/6i4vXUz1218UGi\n58/dnwTeAaU3kJQ+cvT7FZsVfv7a9p27mfUB7wdubHUtLXQucKuXPAp0m9lxrS6q3ZnZXOAsSh9v\nibu/5u57KjZL9tzVOL5cnAP82t0rH9gs/Py1bXMHvg58DvjTDNt8sPwr0yYzmz/DdrFy4L/MbJuZ\nrZ3i9V5g94TlofK6FFQbG8CZZvZzM7vfzE6azeIadCIwAvxb+bLhjWb2uoptUj53tYwP0j1/E60G\nNkyxvvDz15bN3cyWA8+7+7YZNrsH6Hf3RcCDwC2zUlxz/bW7n0rpV8CLzeysVhfURNXG9hilx7T/\nCvhXYPNsF9iAI4BTgW+6+2Lg/4DPt7akpqplfCmfPwDKl5tWAP/Rip/fls2d0od+rzCzZ4A7gLPN\n7LaJG7j7qLu/Wl68EThtdktsnLsPl78+T+ma3+kVmwwDE38j6Suvi161sbn7S+7+cvn7+4BOMzt2\n1gutzxAw5O5by8ubKDXDiZI9d9QwvsTP37j3AY+5+/9O8Vrh568tm7u7X+Hufe7eT+nXpofd/cMT\nt6m4/rWC0o3XZJjZ68zsqPHvgb8F/rtis7uBj5bv3L8L2Ovuz81yqcFqGZuZvcnMrPz96ZT+Wx+d\n7Vrr4e7/A+w2s4XlVecAOys2S/LcQW3jS/n8TXAhU1+SgVk4f+2elpnEzNYBg+5+N3CJma0ADgAv\nAmtaWVsd3gh8v/z/xxHAv7v7FjP7RwB3/xZwH7AMeAr4I/DxFtUaqpaxnQ98yswOAPuA1Z7WE3v/\nBNxe/tX+N8DHMzl346qNL+nzV37T8R7gHyasm9XzpydURUQy1JaXZUREcqfmLiKSITV3EZEMqbmL\niGRIzV1EJENq7iIiGVJzFxHJkJq7iEiG/h86qpKOmdh1nwAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9AlbKsaNhY-J",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# weak classifier: a decision stump (decision tree of depth 1)\n",
        "weak_cla = DecisionTreeClassifier(max_depth=1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AS7dimPwhp2r",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 125
        },
        "outputId": "3f3c28e2-7fa9-437f-db5c-da74f6acb084"
      },
      "source": [
        "# fit\n",
        "weak_cla.fit(X_train, y_train)"
      ],
      "execution_count": 46,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=1,\n",
              "                       max_features=None, max_leaf_nodes=None,\n",
              "                       min_impurity_decrease=0.0, min_impurity_split=None,\n",
              "                       min_samples_leaf=1, min_samples_split=2,\n",
              "                       min_weight_fraction_leaf=0.0, presort=False,\n",
              "                       random_state=None, splitter='best')"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 46
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xOWWIY5ph00J",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "088e2e42-4db2-46e7-e76f-d7b2d048327c"
      },
      "source": [
        "weak_cla_accuracy = weak_cla.score(X_test, y_test)\n",
        "weak_cla_accuracy"
      ],
      "execution_count": 47,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.85"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 47
        }
      ]
    },
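    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see which single split the stump learned, scikit-learn's `export_text` prints a fitted tree as text. The sketch below refits a stump on the same setosa-vs-versicolor data so that it runs on its own; the fixed `random_state` values are assumptions added for reproducibility, so the chosen split and the accuracy may differ from the run above.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from sklearn.datasets import load_iris\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.tree import DecisionTreeClassifier, export_text\n",
        "\n",
        "iris = load_iris()\n",
        "X = iris.data[:100, :2]                       # sepal length, sepal width\n",
        "y = np.where(iris.target[:100] == 0, -1, 1)  # setosa -> -1, versicolor -> +1\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n",
        "\n",
        "stump = DecisionTreeClassifier(max_depth=1, random_state=0).fit(X_train, y_train)\n",
        "print(export_text(stump, feature_names=['sepal length', 'sepal width']))\n",
        "print('test accuracy:', stump.score(X_test, y_test))\n",
        "```\n",
        "\n",
        "A depth-1 tree thresholds a single feature, and neither sepal feature alone separates these two classes, which is why a lone stump stays below perfect accuracy here."
      ]
    },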
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9ezbeV0Tcnza",
        "colab_type": "text"
      },
      "source": [
        "----\n",
        "\n",
        "### AdaBoost\n",
        "Algorithm 8.1 (the numbered comments in the code below refer to its steps and to equations 8.1 to 8.5)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5Ftb2TtQcnzb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class AdaBoost:\n",
        "    # Binary AdaBoost (Algorithm 8.1) with decision stumps; labels must be -1 or +1\n",
        "    def __init__(self, n_estimators=100):\n",
        "        self.clf_num = n_estimators\n",
        "    \n",
        "    def init_args(self, X, y):\n",
        "        self.X = X\n",
        "        self.y = y\n",
        "        M, _ = X.shape\n",
        "        \n",
        "        self.models = []\n",
        "        self.alphas = []\n",
        "        self.weights = np.ones(M) / M  # step 1: uniform initial weights 1/N\n",
        "    \n",
        "    def fit(self, X, y):\n",
        "        self.init_args(X, y)\n",
        "        \n",
        "        for n in range(self.clf_num):\n",
        "            # step 2(a): fit a weak classifier (a stump) on the weighted data\n",
        "            cla = DecisionTreeClassifier(max_depth=1)\n",
        "            cla.fit(X, y, sample_weight=self.weights)\n",
        "            P = cla.predict(X)\n",
        "            \n",
        "            # step 2(b), eq. 8.1: weighted error rate (the weights sum to 1)\n",
        "            err = self.weights.dot(P != y)\n",
        "            # keep err away from 0 and 1 so the log below stays finite\n",
        "            err = np.clip(err, 1e-10, 1 - 1e-10)\n",
        "            # step 2(c), eq. 8.2: coefficient of this classifier\n",
        "            alpha = 0.5 * np.log((1 - err) / err)\n",
        "            \n",
        "            # step 2(d), eqs. 8.3 to 8.5: reweight the samples, then normalize (Z_m)\n",
        "            self.weights = self.weights * np.exp(-alpha * y * P)\n",
        "            self.weights = self.weights / self.weights.sum()\n",
        "            \n",
        "            self.models.append(cla)\n",
        "            self.alphas.append(alpha)\n",
        "            \n",
        "        return 'Done!'\n",
        "    \n",
        "    def predict(self, x):\n",
        "        N, _ = x.shape\n",
        "        FX = np.zeros(N)\n",
        "        \n",
        "        # step 3: f(x) = sum of alpha_m * G_m(x); the final prediction is sign(f(x))\n",
        "        for alpha, cla in zip(self.alphas, self.models):\n",
        "            FX += alpha * cla.predict(x)\n",
        "\n",
        "        return np.sign(FX)\n",
        "    \n",
        "    def score(self, X_test, y_test):\n",
        "        # accuracy: fraction of test samples classified correctly\n",
        "        p = self.predict(X_test)\n",
        "        r = np.sum(p == y_test)\n",
        "        \n",
        "        return r / len(X_test)\n",
        "    \n",
        "    def _weights(self):\n",
        "        # learned coefficients, final sample weights and fitted stumps\n",
        "        return self.alphas, self.weights, self.models"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "onbrb8Xsr5DE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "adaboost = AdaBoost()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kae_0imhr-Ld",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "e6dc8e11-3ede-4ead-e35a-1ba25b26e056"
      },
      "source": [
        "adaboost.fit(X_train, y_train)"
      ],
      "execution_count": 112,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Done!'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 112
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bhWzbeAEsl64",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "0944a671-36b0-4347-9da4-271ce20315be"
      },
      "source": [
        "adaboost.score(X_test, y_test)"
      ],
      "execution_count": 113,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "1.0"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 113
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VQhHJ_rDcnzd",
        "colab_type": "text"
      },
      "source": [
        "### Example 8.1"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qEKxLH3Hcnzd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X_ = np.arange(10).reshape(10, 1)\n",
        "y_ = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iGr8lCCicnzg",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "c1bffec1-902d-4076-e4f4-0318b9acfddd"
      },
      "source": [
        "clf = AdaBoost()\n",
        "clf.fit(X_, y_)"
      ],
      "execution_count": 115,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Done!'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 115
        }
      ]
    },
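    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `AdaBoost` class above delegates each round to scikit-learn's `DecisionTreeClassifier`, which picks its split by weighted Gini impurity rather than by directly minimizing the weighted error $e_m$, so its intermediate models need not match a hand-worked example. Below is a minimal sketch of an error-minimizing threshold classifier for 1-D inputs, run for three rounds on the data of Example 8.1; the helper names `best_stump`, `adaboost_stumps` and `predict_stumps` are hypothetical, and the sketch assumes $0 < e_m < 1$ in every round.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def best_stump(x, y, w):\n",
        "    # Try midpoints between consecutive distinct x values and both polarities:\n",
        "    # polarity +1 predicts +1 for x < v and -1 otherwise, polarity -1 the reverse.\n",
        "    xs = np.unique(x)  # sorted distinct values\n",
        "    best_v, best_p, best_err = None, 1, np.inf\n",
        "    for v in (xs[:-1] + xs[1:]) / 2.0:\n",
        "        for p in (1, -1):\n",
        "            pred = np.where(x < v, p, -p)\n",
        "            err = w[pred != y].sum()\n",
        "            if err < best_err:  # strict '<': ties keep the first candidate scanned\n",
        "                best_v, best_p, best_err = v, p, err\n",
        "    return best_v, best_p, best_err\n",
        "\n",
        "def adaboost_stumps(x, y, n_rounds):\n",
        "    w = np.full(len(x), 1.0 / len(x))\n",
        "    stumps = []\n",
        "    for _ in range(n_rounds):\n",
        "        v, p, err = best_stump(x, y, w)\n",
        "        alpha = 0.5 * np.log((1 - err) / err)\n",
        "        pred = np.where(x < v, p, -p)\n",
        "        w = w * np.exp(-alpha * y * pred)\n",
        "        w = w / w.sum()  # normalize by Z_m\n",
        "        stumps.append((v, p, alpha))\n",
        "    return stumps\n",
        "\n",
        "def predict_stumps(stumps, x):\n",
        "    f = sum(alpha * np.where(x < v, p, -p) for v, p, alpha in stumps)\n",
        "    return np.sign(f)\n",
        "\n",
        "x = np.arange(10)\n",
        "y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])\n",
        "stumps = adaboost_stumps(x, y, n_rounds=3)\n",
        "for v, p, alpha in stumps:\n",
        "    print(f'threshold={v}, polarity={p:+d}, alpha={alpha:.4f}')\n",
        "print('training accuracy:', np.mean(predict_stumps(stumps, x) == y))\n",
        "```\n",
        "\n",
        "On these ten points the three rounds choose thresholds 2.5, 8.5 and 5.5 (the last with reversed polarity), with $\\alpha$ of roughly 0.4236, 0.6496 and 0.7520, and the sign of the weighted vote classifies all ten training points correctly."
      ]
    },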
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YTbvHicmcnzq",
        "colab_type": "text"
      },
      "source": [
        "-----\n",
        "# sklearn.ensemble.AdaBoostClassifier\n",
        "\n",
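        "- algorithm: only AdaBoostClassifier has this parameter, because scikit-learn implements two AdaBoost classification algorithms, SAMME and SAMME.R. They differ mainly in how the weak learners are weighted: SAMME extends the binary AdaBoost algorithm described above and weights each weak learner by its classification performance on the sample set, while SAMME.R uses the predicted class probabilities. Because SAMME.R works with continuous probability estimates, it usually needs fewer iterations than SAMME, which is why SAMME.R was the default value of `algorithm` in the scikit-learn version this notebook was run with, and the default is usually good enough there. Note, however, that SAMME.R requires the weak learner (`base_estimator`) to support probability prediction; SAMME has no such restriction. Newer scikit-learn releases have deprecated and then removed SAMME.R, so check the documentation of the installed version.\n",
        "\n",
        "- n_estimators: available in both AdaBoostClassifier and AdaBoostRegressor; the maximum number of boosting iterations, i.e. the maximum number of weak learners (boosting stops early on a perfect fit). Too small a value tends to underfit and too large a value tends to overfit, so a moderate value is usually chosen; the default is 50. In practice n_estimators is tuned together with learning_rate, described next.\n",
        "\n",
        "- learning_rate: available in both; the shrinkage coefficient $\\nu$ applied to each weak learner's weight, so that each update is $f_m(x) = f_{m-1}(x) + \\nu\\alpha_m G_m(x)$. A smaller $\\nu$ generally needs more weak learners to reach the same fit; the default is 1.0.\n",
        "\n",
        "- base_estimator: available in both; the weak classifier or weak regressor. In theory any classifier or regressor can be used as long as it supports sample weights. CART decision trees are the usual choice (AdaBoostClassifier defaults to a depth-1 DecisionTreeClassifier); an MLP neural network only works if its implementation accepts sample weights. From scikit-learn 1.2 on, this parameter is called `estimator`."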
      ]
    },
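    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A sketch of how `n_estimators` and `learning_rate` interact, assuming scikit-learn 1.2 or later (where the base learner is passed as `estimator`) and the same setosa-vs-versicolor subset; `staged_score` yields the test accuracy after each boosting round. The fixed `random_state` values and the three learning rates are illustrative assumptions.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from sklearn.datasets import load_iris\n",
        "from sklearn.ensemble import AdaBoostClassifier\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.tree import DecisionTreeClassifier\n",
        "\n",
        "# same binary task: first 100 iris samples, sepal length and width only\n",
        "iris = load_iris()\n",
        "X = iris.data[:100, :2]\n",
        "y = np.where(iris.target[:100] == 0, -1, 1)\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n",
        "\n",
        "results = {}\n",
        "for lr in (1.0, 0.5, 0.1):\n",
        "    clf = AdaBoostClassifier(\n",
        "        estimator=DecisionTreeClassifier(max_depth=1),  # explicit stump base learner\n",
        "        n_estimators=100,\n",
        "        learning_rate=lr,\n",
        "        random_state=0,\n",
        "    )\n",
        "    clf.fit(X_train, y_train)\n",
        "    # one accuracy value per boosting round actually performed\n",
        "    results[lr] = list(clf.staged_score(X_test, y_test))\n",
        "\n",
        "for lr, scores in results.items():\n",
        "    print(f'learning_rate={lr}: rounds={len(scores)}, first={scores[0]:.2f}, last={scores[-1]:.2f}')\n",
        "```\n",
        "\n",
        "Comparing the per-round accuracies across learning rates shows how strongly the two parameters trade off on this small test set; the number of rounds can be below `n_estimators` because boosting stops early on a perfect fit."
      ]
    },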
    {
      "cell_type": "code",
      "metadata": {
        "id": "CGLto18Ycnzr",
        "colab_type": "code",
        "outputId": "4c41d66b-b820-4dda-a212-9d18db023e1b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 53
        }
      },
      "source": [
        "from sklearn.ensemble import AdaBoostClassifier\n",
        "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n",
        "clf.fit(X_train, y_train)"
      ],
      "execution_count": 86,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None, learning_rate=0.5,\n",
              "                   n_estimators=100, random_state=None)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 86
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XHy-rTRwcnzu",
        "colab_type": "code",
        "outputId": "41b0c4e7-100f-4b18-f25d-a15e33d6060e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "clf.score(X_test, y_test)"
      ],
      "execution_count": 87,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "1.0"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 87
        }
      ]
    }
  ]
}